package main

// removeElements returns the list starting at head with every node whose Val
// equals val unlinked. Surviving nodes are relinked in place (their Next
// fields may be rewritten) and no new nodes are allocated. A nil head, or a
// list in which every node matches val, yields nil.
func removeElements(head *ListNode, val int) *ListNode {
	// Drop matching nodes at the front. The nil check must come first: if
	// every node matches, head runs off the end and reading head.Val would
	// panic with a nil pointer dereference.
	for head != nil && head.Val == val {
		head = head.Next
	}

	// Invariant: cur is always a kept node (cur.Val != val). Skip over any
	// run of matching successors, splice cur to the first kept one (or nil),
	// then advance to it.
	for cur := head; cur != nil; {
		next := cur.Next
		for next != nil && next.Val == val {
			next = next.Next
		}
		cur.Next = next
		cur = next
	}

	return head
}
